package com.zhuodewen.ai.controller.langchain4j;

import com.zhuodewen.ai.base.JSONResult;
import com.zhuodewen.ai.config.ai.LangChain4JProvider;
import com.zhuodewen.ai.constant.CommonConstants;
import com.zhuodewen.ai.dto.langchain4j.LangChain4JTool;
import com.zhuodewen.ai.service.langchain4j.LangChain4JAiService;
import com.zhuodewen.ai.service.langchain4j.LangChain4JAssistant;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.openai.OpenAiImageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;

import java.util.List;
import java.util.function.Consumer;

/**
 * REST controller for the LangChain4J demo endpoints.
 * Converts incoming request parameters and outgoing model responses
 * to/from the application's JSON response format.
 */
@RestController                             //等于@ResponseBody(返回JSON格式的数据) + @Controller(定义为Controller接口类)
@RequestMapping(value = "langchain4j")      //路径映射
@Slf4j                                      //日志
public class LangChain4JController {

    /**
     * Knowledge-base directory scanned by {@link #ragTest(String)}.
     * NOTE(review): hard-coded local path — should be externalized to
     * application configuration so the endpoint works outside this machine.
     */
    private static final String KNOWLEDGE_BASE_DIR = "D:/卓德文/其他/AI知识库/知识库";

    /** Provider of the configured chat / streaming / embedding / image models. */
    private final LangChain4JProvider langChain4JProvider;

    /**
     * Constructor injection (preferred over field injection: the dependency is
     * final, mandatory, and the class is testable without a Spring context).
     *
     * @param langChain4JProvider supplies the configured LangChain4J models
     */
    @Autowired
    public LangChain4JController(LangChain4JProvider langChain4JProvider) {
        this.langChain4JProvider = langChain4JProvider;
    }

    /**
     * Blocking chat against the DeepSeek (OpenAI-compatible) model.
     *
     * @param message user prompt
     * @return JSON result wrapping the model's full reply
     */
    @PostMapping("deepseek/chat")
    public JSONResult deepSeekChat(@RequestParam(value = "message") String message) {
        String callback = langChain4JProvider.getOpenAiChatModel().chat(message);
        return new JSONResult().markSuccess(CommonConstants.RESULT_SUCCESS_MSG, callback);
    }

    /**
     * Streaming chat against the DeepSeek model, delivered as server-sent events.
     *
     * @param message user prompt
     * @return a {@link Flux} emitting each partial response token as it arrives
     */
    @PostMapping(value = "deepseek/streamingChat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<String> deepSeekStreamingChat(@RequestParam(value = "message") String message) {
        return streamToFlux(handler ->
                langChain4JProvider.getOpenAiStreamingChatModel().chat(message, handler));
    }

    /**
     * Blocking chat against the local Ollama model.
     *
     * @param message user prompt
     * @return JSON result wrapping the model's full reply
     */
    @PostMapping("ollama/chat")
    public JSONResult ollamaChat(@RequestParam(value = "message") String message) {
        String callback = langChain4JProvider.getOllamaChatModel().chat(message);
        return new JSONResult().markSuccess(CommonConstants.RESULT_SUCCESS_MSG, callback);
    }

    /**
     * Streaming chat against the local Ollama model, delivered as server-sent events.
     *
     * @param message user prompt
     * @return a {@link Flux} emitting each partial response token as it arrives
     */
    @PostMapping(value = "ollama/streamingChat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<String> ollamaStreamingChat(@RequestParam(value = "message") String message) {
        return streamToFlux(handler ->
                langChain4JProvider.getOllamaStreamingChatModel().chat(message, handler));
    }

    /**
     * Bridges a streaming model invocation onto a reactive {@link Flux}.
     * The caller supplies the model call; this method wires the model's
     * callback handler to the Flux sink (token → next, completion → complete,
     * failure → error). Shared by both streaming endpoints to avoid
     * duplicating the anonymous handler.
     *
     * @param invoker invoked once per subscription with the handler to register
     * @return a Flux of partial response tokens
     */
    private Flux<String> streamToFlux(Consumer<StreamingChatResponseHandler> invoker) {
        return Flux.create(fluxSink -> invoker.accept(new StreamingChatResponseHandler() {
            @Override
            public void onPartialResponse(String token) {
                // Forward each partial response chunk to the SSE stream.
                fluxSink.next(token);
            }

            @Override
            public void onCompleteResponse(ChatResponse chatResponse) {
                // Was System.out.println — route diagnostics through SLF4J instead.
                log.debug("onCompleteResponse: {}", chatResponse);
                fluxSink.complete();
            }

            @Override
            public void onError(Throwable throwable) {
                fluxSink.error(throwable);
            }
        }));
    }

    /**
     * RAG demo: loads a local knowledge base into an in-memory embedding store
     * and answers the question against it.
     *
     * @param message question to answer from the knowledge base
     * @return JSON result wrapping the assistant's answer
     */
    @PostMapping("ragTest")
    public JSONResult ragTest(@RequestParam(value = "message") String message) {
        // 1. Load the knowledge base.
        //    Alternatives kept for reference:
        //    - loadDocumentsRecursively(dir) to include all subdirectories
        //    - loadDocuments(dir, pathMatcher) with a glob PathMatcher to filter (e.g. "glob:*.pdf")
        List<Document> documents = FileSystemDocumentLoader.loadDocuments(KNOWLEDGE_BASE_DIR);

        // 2. Ingest the documents into an in-memory vector store using default
        //    splitting/embedding. A custom Ollama embedding model was tried and
        //    failed with an embedding-dimension mismatch — unresolved:
        //    EmbeddingStoreIngestor.builder()
        //        .embeddingModel(langChain4JProvider.getOllamaEmbeddingModel())
        //        .embeddingStore(embeddingStore)
        //        .build()
        //        .ingest(documents);
        InMemoryEmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
        EmbeddingStoreIngestor.ingest(documents, embeddingStore);

        // 3. Build an assistant with chat memory and a retriever over the store.
        //    NOTE(review): the store is rebuilt on every request — expensive;
        //    consider ingesting once at startup and reusing the assistant.
        LangChain4JAssistant langChain4JAssistant = AiServices.builder(LangChain4JAssistant.class)
                .chatLanguageModel(langChain4JProvider.getOllamaChatModel())
                .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
                .contentRetriever(EmbeddingStoreContentRetriever.from(embeddingStore))
                .build();

        // 4. Answer the question with retrieval-augmented context.
        String chat = langChain4JAssistant.chat(message);
        log.info("知识库检索内容:{},检索结果:{}", message, chat);
        return new JSONResult().markSuccess(CommonConstants.RESULT_SUCCESS_MSG, chat);
    }

    /**
     * Tool-calling demo: lets the model invoke methods on {@link LangChain4JTool}.
     *
     * @param message user prompt that may trigger tool calls
     * @return JSON result wrapping the model's answer
     */
    @PostMapping("toolTest")
    public JSONResult toolTest(@RequestParam(value = "message") String message) {
        LangChain4JAiService langChain4JAiService = AiServices.builder(LangChain4JAiService.class)
                .chatLanguageModel(langChain4JProvider.getOpenAiChatModel())
                .tools(new LangChain4JTool())
                .build();
        String answer = langChain4JAiService.ask(message);
        return new JSONResult().markSuccess(CommonConstants.RESULT_SUCCESS_MSG, answer);
    }

    /**
     * Multimodal demo: generates an image from a text prompt.
     * (Common domestic multimodal providers are not yet supported here.)
     *
     * @param message text prompt describing the image
     * @return JSON result wrapping the generated image's URL
     */
    @PostMapping("imageTest")
    public JSONResult imageTest(@RequestParam(value = "message") String message) {
        OpenAiImageModel openAiImageModel = langChain4JProvider.getOpenAiImageModel();
        Response<Image> response = openAiImageModel.generate(message);
        log.info("图片生成结果:{}", response);
        return new JSONResult().markSuccess(CommonConstants.RESULT_SUCCESS_MSG, response.content().url());
    }
}
